import numpy as np
from multiagent.core import World, Agent, Landmark
from multiagent.scenario import BaseScenario
from scipy.optimize import linear_sum_assignment


def get_thetas(poses):
    """Return a list with one angle per entry of ``poses``.

    Each entry is passed to ``find_angle``, so every element must be
    indexable as ``pose[0]`` (x) and ``pose[1]`` (y). The output order
    matches the input order.
    """
    return [find_angle(pose) for pose in poses]


def find_angle(pose):
    """Return the angle of ``pose`` from the horizontal, as a non-negative value.

    ``pose[0]`` is used as the x component and ``pose[1]`` as the y
    component of ``np.arctan2``. A negative result is shifted up by a
    full turn (2*pi).
    """
    x_coord, y_coord = pose[0], pose[1]
    theta = np.arctan2(y_coord, x_coord)
    # Wrap the negative half of arctan2's output onto the upper range.
    return theta + 2*np.pi if theta < 0 else theta


class Scenario(BaseScenario):
    """Scenario in which agents are rewarded for forming rings around two landmarks.

    For each landmark a ring of target positions of radius
    ``target_radius`` is generated, anchored at the smallest agent angle
    (relative to that landmark) and spaced by ``ideal_theta_separation``.
    Agents are matched one-to-one to the union of all target positions
    with a minimum-cost assignment; the shared reward is the negative
    mean of the (clipped) matched distances.
    """

    def __init__(self, num_agents=4, dist_threshold=0.1, arena_size=1, identity_size=0):
        """Store the scenario configuration.

        Args:
            num_agents: number of agents created by ``make_world``.
            dist_threshold: accepted for interface compatibility.
                NOTE(review): this value is not used; the success
                threshold is fixed at 0.125 below -- confirm intent.
            arena_size: half-width of the square in which agents are
                spawned (landmarks use 0.8 of it).
            identity_size: if non-zero, length of the one-hot identity
                vector prepended to each observation.
        """
        self.num_agents = num_agents
        self.target_radius = 0.5  # fixing the target radius for now
        # Agents are shared between two landmarks, hence 2 * (2*pi) / num_agents.
        self.ideal_theta_separation = (2*2*np.pi)/self.num_agents
        self.arena_size = arena_size
        self.dist_thres = 0.125
        self.theta_thres = 0.1
        self.identity_size = identity_size
        self._reset_episode_stats()

    def _reset_episode_stats(self):
        """Reset per-episode bookkeeping so done()/info() are safe before reward()."""
        self.flag = [False, False]
        self.count1 = 0
        self.count2 = 0
        self.count3 = 0
        self.count4 = 0
        # Infinite distances mean "not yet evaluated": never counted as success.
        self.delta_dists = np.full(self.num_agents, np.inf)
        self.is_success = False
        self.joint_reward = 0.0

    def make_world(self):
        """Build a world with ``num_agents`` agents and two landmarks, then reset it."""
        world = World()
        # set any world properties first
        world.dim_c = 2
        num_landmarks = 2
        world.collaborative = False

        # add agents
        world.agents = [Agent(iden=i) for i in range(self.num_agents)]
        for i, agent in enumerate(world.agents):
            agent.name = 'agent %d' % i
            agent.collide = True
            agent.silent = True
            agent.size = 0.05
            agent.adversary = False

        # add landmarks
        world.landmarks = [Landmark() for _ in range(num_landmarks)]
        for i, landmark in enumerate(world.landmarks):
            landmark.name = 'landmark %d' % i
            landmark.collide = False
            landmark.movable = False
            landmark.size = 0.03

        # make initial conditions
        self.reset_world(world)
        world.dists1 = []
        world.dists2 = []
        return world

    def reset_world(self, world):
        """Recolor entities, randomize their positions and clear episode state."""
        for agent in world.agents:
            # NOTE(review): landmark colors use 0-1 floats; 255 here may be
            # intended as 1.0 -- confirm against the renderer.
            agent.color = np.array([0, 0, 255])

        for landmark in world.landmarks:
            landmark.color = np.array([0.25, 0.25, 0.25])

        # set random initial states
        for agent in world.agents:
            agent.state.p_pos = np.random.uniform(-self.arena_size, self.arena_size, world.dim_p)
            agent.state.p_vel = np.zeros(world.dim_p)
            agent.state.c = np.zeros(world.dim_c)
        for landmark in world.landmarks:
            # bound on the landmark position less than that of the environment for visualization purposes
            landmark.state.p_pos = np.random.uniform(-.8*self.arena_size, .8*self.arena_size, world.dim_p)
            landmark.state.p_vel = np.zeros(world.dim_p)

        world.steps = 0
        world.dists1 = []
        world.dists2 = []
        self._reset_episode_stats()

    def _ring_targets(self, center, anchor_theta, count):
        """Return ``count`` points on the target ring around ``center``.

        Points start at ``anchor_theta`` and advance by
        ``ideal_theta_separation``; the result has shape (count, 2).
        """
        angles = anchor_theta + np.arange(count) * self.ideal_theta_separation
        offsets = np.stack([np.cos(angles), np.sin(angles)], axis=1)
        return center + self.target_radius * offsets

    def reward(self, agent, world):
        """Compute the joint reward shared by all agents.

        Side effects: stores the matched distances in ``self.delta_dists``
        and ``world.dists`` and the reward in ``self.joint_reward``.

        Returns:
            Negative mean of the matched agent-to-target distances, each
            clipped to [0, 2]. The ``agent`` argument is not used; the
            value is identical for every agent.
        """
        agent_positions = np.array([a.state.p_pos for a in world.agents])

        # Number of target slots per landmark, rounded up so every agent can be
        # matched. (Previously hard-coded to 5, which is only right for 10 agents.)
        num_landmarks = len(world.landmarks)
        slots_per_landmark = (self.num_agents + num_landmarks - 1) // num_landmarks

        rings = []
        for landmark in world.landmarks:
            center = landmark.state.p_pos
            thetas = get_thetas(agent_positions - center)
            # anchor at the agent with min theta (closest to the horizontal line)
            rings.append(self._ring_targets(center, min(thetas), slots_per_landmark))
        expected_poses = np.concatenate(rings, axis=0)

        # optimal 1:1 agent-target pairing (bipartite matching algorithm)
        dists = np.linalg.norm(
            agent_positions[:, None, :] - expected_poses[None, :, :], axis=-1)
        self.delta_dists = self._bipartite_min_dists(dists)
        world.dists = self.delta_dists

        total_penalty = np.mean(np.clip(self.delta_dists, 0, 2))
        self.joint_reward = -total_penalty
        return self.joint_reward

    def _bipartite_min_dists(self, dists):
        """Return the distances selected by a minimum-cost assignment on ``dists``."""
        ri, ci = linear_sum_assignment(dists)
        min_dists = dists[ri, ci]
        return min_dists

    def observation(self, agent, world):
        """Return the observation vector for ``agent``.

        Layout: [one-hot identity (only if ``identity_size`` != 0)],
        own velocity, own position, then each landmark's position
        relative to the agent.
        """
        # positions of all landmarks in this agent's reference frame
        entity_pos = [entity.state.p_pos - agent.state.p_pos for entity in world.landmarks]
        default_obs = np.concatenate([agent.state.p_vel] + [agent.state.p_pos] + entity_pos)
        if self.identity_size != 0:
            identified_obs = np.append(np.eye(self.identity_size)[agent.iden], default_obs)
            return identified_obs
        return default_obs

    def done(self, agent, world):
        """Return True when the step limit is reached or all agents are on target.

        Success means every matched distance from the latest ``reward``
        call is below ``dist_thres``; the result is cached in
        ``self.is_success``.
        """
        condition1 = world.steps >= world.max_steps_episode
        self.is_success = np.all(self.delta_dists < self.dist_thres)
        return condition1 or self.is_success

    def info(self, agent, world):
        """Return a dict with success flag, step count, last reward and mean distance."""
        return {'is_success': self.is_success, 'world_steps': world.steps,
                'reward': self.joint_reward, 'dists': self.delta_dists.mean()}
